--- train_triplane.py	2023-04-06 03:37:58.653129514 +0000
+++ external_reference/train_triplane.py	2023-04-06 03:32:01.227232133 +0000
@@ -22,8 +22,8 @@
 import torch
 
 import dnnlib
-from training import training_loop
-from metrics import metric_main
+from external.eg3d.training import training_loop
+from external.stylegan.metrics import metric_main
 from torch_utils import training_stats
 from torch_utils import custom_ops
 
@@ -106,7 +106,7 @@
 
 def init_dataset_kwargs(data):
     try:
-        dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data, use_labels=True, max_size=None, xflip=False)
+        dataset_kwargs = dnnlib.EasyDict(class_name='external.eg3d.training.dataset.ImageFolderDataset', path=data, use_labels=True, max_size=None, xflip=False)
         dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # Subclass of training.dataset.Dataset.
         dataset_kwargs.resolution = dataset_obj.resolution # Be explicit about resolution.
         dataset_kwargs.use_labels = dataset_obj.has_labels # Be explicit about labels.
@@ -178,7 +178,7 @@
 @click.option('--gpc_reg_prob', help='Strength of swapping regularization. None means no generator pose conditioning, i.e. condition with zeros.', metavar='FLOAT',  type=click.FloatRange(min=0), required=False, default=0.5)
 @click.option('--gpc_reg_fade_kimg', help='Length of swapping prob fade', metavar='INT',  type=click.IntRange(min=0), required=False, default=1000)
 @click.option('--disc_c_noise', help='Strength of discriminator pose conditioning regularization, in standard deviations.', metavar='FLOAT',  type=click.FloatRange(min=0), required=False, default=0)
-@click.option('--sr_noise_mode', help='Type of noise for superresolution', metavar='STR',  type=click.Choice(['random', 'none']), required=False, default='none')
+@click.option('--sr_noise_mode', help='Type of noise for superresolution', metavar='STR',  type=click.Choice(['random', 'none', '3dnoise']), required=False, default='none')
 @click.option('--resume_blur', help='Enable to blur even on resume', metavar='BOOL',  type=bool, required=False, default=False)
 @click.option('--sr_num_fp16_res',    help='Number of fp16 layers in superresolution', metavar='INT', type=click.IntRange(min=0), default=4, required=False, show_default=True)
 @click.option('--g_num_fp16_res',    help='Number of fp16 layers in generator', metavar='INT', type=click.IntRange(min=0), default=0, required=False, show_default=True)
@@ -193,6 +193,26 @@
 @click.option('--reg_type', help='Type of regularization', metavar='STR',  type=click.Choice(['l1', 'l1-alt', 'monotonic-detach', 'monotonic-fixed', 'total-variation']), required=False, default='l1')
 @click.option('--decoder_lr_mul',    help='decoder learning rate multiplier.', metavar='FLOAT', type=click.FloatRange(min=0), default=1, required=False, show_default=True)
 
+#### added kwargs
+@click.option('--depth_clip', help='depth clip factor', metavar='FLOAT',
+              type=click.FloatRange(min=0), default=20., show_default=True)
+@click.option('--depth_scale', help='depth scale factor', metavar='FLOAT',
+              type=click.FloatRange(min=0), default=16., show_default=True)
+@click.option('--white_sky', help='sky pixels white',
+              metavar='BOOL', type=bool, default=False, show_default=True)
+@click.option('--ignore_LR_disp', help='ignore disp to LR discriminator input',
+              metavar='BOOL', type=bool, default=False, show_default=True)
+@click.option('--ignore_HR_disp', help='ignore disp to HR discriminator input',
+              metavar='BOOL', type=bool, default=True, show_default=True)
+@click.option('--plane_resolution', help='resolution of the stylegan planes',
+              metavar='INT', type=click.IntRange(min=0), default=256, show_default=True)
+@click.option('--ignore_w_to_upsampler', help='ignore w input to upsampler',
+              metavar='BOOL', type=bool, default=True, show_default=True)
+@click.option('--lambda_sky_pixel', help='sky penalty', metavar='FLOAT',
+              type=click.FloatRange(min=0), default=0., show_default=True)
+@click.option('--lambda_ramp_end', help='sky penalty ramp function', metavar='FLOAT',
+              type=click.FloatRange(min=0), default=1000, show_default=True)
+
 def main(**kwargs):
     """Train a GAN using the techniques described in the paper
     "Alias-Free Generative Adversarial Networks".
@@ -220,10 +240,10 @@
     opts = dnnlib.EasyDict(kwargs) # Command line arguments.
     c = dnnlib.EasyDict() # Main config dict.
     c.G_kwargs = dnnlib.EasyDict(class_name=None, z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict())
-    c.D_kwargs = dnnlib.EasyDict(class_name='training.networks_stylegan2.Discriminator', block_kwargs=dnnlib.EasyDict(), mapping_kwargs=dnnlib.EasyDict(), epilogue_kwargs=dnnlib.EasyDict())
+    c.D_kwargs = dnnlib.EasyDict(class_name='external.stylegan.training.networks_stylegan2_terrain.Discriminator', block_kwargs=dnnlib.EasyDict(), mapping_kwargs=dnnlib.EasyDict(), epilogue_kwargs=dnnlib.EasyDict())
     c.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0.99], eps=1e-8)
     c.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0,0.99], eps=1e-8)
-    c.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss')
+    c.loss_kwargs = dnnlib.EasyDict(class_name='external.eg3d.training.loss.StyleGAN2Loss')
     c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, prefetch_factor=2)
 
     # Training set.
@@ -232,6 +252,10 @@
         raise click.ClickException('--cond=True requires labels specified in dataset.json')
     c.training_set_kwargs.use_labels = opts.cond
     c.training_set_kwargs.xflip = opts.mirror
+    # added training set kwargs
+    c.training_set_kwargs.depth_scale = opts.depth_scale
+    c.training_set_kwargs.depth_clip = opts.depth_clip
+    c.training_set_kwargs.white_sky = opts.white_sky
 
     # Hyperparameters & settings.
     c.num_gpus = opts.gpus
@@ -264,18 +288,18 @@
 
     # Base configuration.
     c.ema_kimg = c.batch_size * 10 / 32
-    c.G_kwargs.class_name = 'training.triplane.TriPlaneGenerator'
-    c.D_kwargs.class_name = 'training.dual_discriminator.DualDiscriminator'
+    c.G_kwargs.class_name = 'external.eg3d.training.triplane.TriPlaneGenerator'
+    c.D_kwargs.class_name = 'external.eg3d.training.dual_discriminator.DualDiscriminator'
     c.G_kwargs.fused_modconv_default = 'inference_only' # Speed up training by using regular convolutions instead of grouped convolutions.
     c.loss_kwargs.filter_mode = 'antialiased' # Filter mode for raw images ['antialiased', 'none', float [0-1]]
     c.D_kwargs.disc_c_noise = opts.disc_c_noise # Regularization for discriminator pose conditioning
 
     if c.training_set_kwargs.resolution == 512:
-        sr_module = 'training.superresolution.SuperresolutionHybrid8XDC'
+        sr_module = 'external.eg3d.training.superresolution.SuperresolutionHybrid8XDC'
     elif c.training_set_kwargs.resolution == 256:
-        sr_module = 'training.superresolution.SuperresolutionHybrid4X'
+        sr_module = 'external.eg3d.training.superresolution.SuperresolutionHybrid4X'
     elif c.training_set_kwargs.resolution == 128:
-        sr_module = 'training.superresolution.SuperresolutionHybrid2X'
+        sr_module = 'external.eg3d.training.superresolution.SuperresolutionHybrid2X'
     else:
         assert False, f"Unsupported resolution {c.training_set_kwargs.resolution}; make a new superresolution module"
     
@@ -329,10 +353,29 @@
             'avg_camera_radius': 1.7,
             'avg_camera_pivot': [0, 0, 0],
         })
+    elif opts.cfg == 'lhq':
+        rendering_options.update({
+            'depth_resolution': 64,
+            'depth_resolution_importance': 64,
+            'ray_start': 1.0,
+            'ray_end': opts.depth_scale, # dataset depth scale should match nerf far bound
+            'box_warp': 38.4,
+            'white_back': opts.white_sky,
+            'sample_deterministic': False, # add sampling noise for train
+            'y_clip': None, # don't clip y coordinate for training
+            'decay_start': None, # don't decay far content for training
+            'avg_camera_radius': 1.7, # not used
+            'avg_camera_pivot': [0, 0, 0], # not used
+        })
     else:
         assert False, "Need to specify config"
 
 
+    # added loss kwargs
+    c.loss_kwargs.ignore_LR_disp = opts.ignore_lr_disp
+    c.loss_kwargs.ignore_HR_disp = opts.ignore_hr_disp
+    c.loss_kwargs.lambda_sky_pixel = opts.lambda_sky_pixel
+    c.loss_kwargs.lambda_ramp_end = opts.lambda_ramp_end * 1000 # option presumably given in kimg; scaled by 1000 (kimg -> images) -- TODO confirm against loss
 
     if opts.density_reg > 0:
         c.G_reg_interval = opts.density_reg_every
@@ -351,6 +394,10 @@
 
     c.G_kwargs.sr_kwargs = dnnlib.EasyDict(channel_base=opts.cbase, channel_max=opts.cmax, fused_modconv_default='inference_only')
 
+    # added model kwargs
+    c.G_kwargs.plane_resolution = opts.plane_resolution
+    c.G_kwargs.sr_kwargs.ignore_w = opts.ignore_w_to_upsampler
+
     c.loss_kwargs.style_mixing_prob = opts.style_mixing_prob
 
     # Augmentation.
